"""An example of finetuning Qwen-VL via Direct Preference Optimization (DPO)."""

import json
import logging
import os
from collections import defaultdict
from dataclasses import dataclass, field
from itertools import combinations
from typing import Dict, List, Optional
import wandb
from azfuse import File

import datasets
from datasets import load_dataset, concatenate_datasets
import numpy as np
import torch.distributed
import transformers
from accelerate.utils import DistributedType
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import GPTQConfig, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
from trl.trainer import DPOTrainer
from trl.trainer.utils import DPODataCollatorWithPadding

# Sentinel label id: `preprocess` writes it into target positions that should
# be ignored (system/user tokens), and `train` passes it to the collator as
# `label_pad_token_id`. The value comes from transformers' LabelSmoother.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


@dataclass
class ModelArguments:
    """Command-line options selecting the pretrained checkpoint."""

    # Identifier or path handed to the `from_pretrained` loaders in `train`.
    model_name_or_path: Optional[str] = "Qwen/Qwen-VL-Chat"


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Hugging Face training arguments extended with the options this script reads."""

    # Forwarded as `cache_dir` to the config/model/tokenizer loaders.
    cache_dir: Optional[str] = None
    model_max_length: int = field(
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
        default=8192,
    )
    # When set, a LoRA config is built and handed to the DPO trainer.
    use_lora: bool = field(default=False)
    # When set (and LoRA is off), the visual tower is frozen except `attn_pool`.
    fix_vit: bool = field(default=True)
    # Passed to DPOTrainer as `beta`.
    beta: float = 0.1
    # Passed to DPOTrainer as `generate_during_eval`.
    generate_during_eval: bool = False


@dataclass
class DataArguments:
    """Locations of the preference datasets and their image folders.

    Two independent (data, image folder) pairs can be supplied; `train` loads
    each pair that is fully specified and concatenates the results.
    """

    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    image_folder: Optional[str] = field(
        default=None,
        metadata={
            "help": "Folder containing the images referenced by the training data."
        },
    )
    unk_data_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the second ('unk') training data set."},
    )
    unk_image_folder: Optional[str] = field(
        default=None,
        metadata={
            "help": "Folder containing the images referenced by the 'unk' training data."
        },
    )
    dataset_debug: bool = field(
        default=False,
        metadata={
            "help": "Keep only the first 100 records of JSON-based datasets for quick debugging."
        },
    )



@dataclass
class LoraArguments:
    """LoRA / QLoRA settings; `train` forwards most of them to `LoraConfig`."""

    lora_r: int = field(default=64)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    # Module-name patterns that receive adapters.
    # Alternative set left by the original author: ["in_proj", "out_proj", "c_fc"]
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["c_attn", "attn.c_proj", "w1", "w2"]
    )
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")
    # Enables the GPTQ 4-bit loading path in `train` (together with `use_lora`).
    q_lora: bool = field(default=False)


def maybe_zero_3(param):
    """Return a detached CPU clone of ``param``.

    Parameters carrying a ``ds_id`` attribute are gathered inside
    ``zero.GatheredParameters`` before being copied; anything else is copied
    directly.
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
    with zero.GatheredParameters([param]):
        gathered = param.data.detach().cpu().clone()
    return gathered


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Select the LoRA (and optionally bias) entries from ``named_params``.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs, e.g.
            ``model.named_parameters()``. It is traversed exactly once, so a
            generator is fine.
        bias: Which bias entries to keep next to the ``"lora_"`` entries:
            ``"none"`` keeps no biases, ``"all"`` keeps every entry whose name
            contains ``"bias"``, and ``"lora_only"`` keeps only the biases
            whose name is ``<prefix>bias`` for some kept ``<prefix>lora_...``
            entry.

    Returns:
        Dict mapping each selected name to ``maybe_zero_3(tensor)``.

    Raises:
        NotImplementedError: If ``bias`` is not one of the three modes above.
    """
    if bias == "none":
        selected = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        selected = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        selected = {}
        candidate_biases = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                selected[k] = t
                # Name of the bias that belongs to the same module as this adapter.
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                candidate_biases[k] = t
        # Iterate (name, tensor) pairs and test each candidate's own name; the
        # previous code unpacked the key strings and reused a stale name.
        for k, t in candidate_biases.items():
            if k in lora_bias_names:
                selected[k] = t
    else:
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    return {k: maybe_zero_3(v) for k, v in selected.items()}


# Module-level rank of this process; `train` overwrites it with
# `training_args.local_rank` once the arguments are parsed.
local_rank = None


def rank0_print(*args):
    """Forward ``args`` to ``print`` only when the module-level ``local_rank`` is 0."""
    if local_rank != 0:
        return
    print(*args)


def safe_save_model_for_hf_trainer(
    trainer: transformers.Trainer, output_dir: str, bias="none"
):
    """Collect the model state dict, save it, and upload ``output_dir``'s files.

    Args:
        trainer: Trainer whose model is saved. ``trainer.args`` must expose
            ``use_lora``, ``should_save`` and ``local_rank``.
        output_dir: Directory the checkpoint is written to and uploaded from.
        bias: Bias mode forwarded to ``get_peft_state_maybe_zero_3`` when LoRA
            is used without ZeRO-3.

    The state dict is collected on every rank. Writing to disk keeps the
    original condition (``should_save`` and ``local_rank == 0``). Uploading is
    limited to processes with ``should_save`` set, so ranks that did not write
    the checkpoint never read or re-upload files another rank may still be
    writing.
    """
    # NOTE(review): the ZeRO-3 consolidation is presumably a collective call,
    # which is why state-dict collection stays outside any rank check - confirm.
    if deepspeed.is_deepspeed_zero3_enabled():
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    elif trainer.args.use_lora:
        state_dict = get_peft_state_maybe_zero_3(
            trainer.model.named_parameters(), bias
        )
    else:
        state_dict = trainer.model.state_dict()

    if trainer.args.should_save and trainer.args.local_rank == 0:
        trainer._save(output_dir, state_dict=state_dict)

    if not trainer.args.should_save:
        return

    # Upload the checkpoint: each regular file is re-written through azfuse.
    with File.async_upload(enabled=True):
        for name in os.listdir(output_dir):
            filepath = os.path.join(output_dir, name)
            # Sub-directories (e.g. intermediate checkpoints) cannot be opened
            # as files; skip them instead of crashing.
            if not os.path.isfile(filepath):
                continue
            with open(filepath, "rb") as src:
                content = src.read()
            with File.open(filepath, "wb") as fp:
                fp.write(content)


def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int,
    system_message: str = "You are a helpful assistant.",
) -> Dict:
    """Tokenize conversations with the Qwen chat markers, split into prompt and answer.

    Every conversation is rendered as a system turn followed by its user and
    assistant turns, each wrapped as ``<|im_start|>role``, newline, text,
    ``<|im_end|>``, newline.

    Args:
        sources: List of conversations. Each conversation is a list of dicts
            with a ``"from"`` key (``"user"`` or ``"assistant"``) and a
            ``"value"`` key holding the turn's text.
        tokenizer: Tokenizer exposing ``im_start_id``, ``im_end_id`` and
            ``pad_token_id``; it is called on raw strings and its
            ``input_ids`` are used.
        max_len: Not used by this function.
        system_message: Text of the system turn prepended to every
            conversation.

    Returns:
        A 2-tuple ``(prompt_sequence_tokens, answer_sequence_tokens)`` (the
        ``Dict`` return annotation is narrower than what is returned). Each
        element is a dict with ``input_ids``, ``labels`` and
        ``attention_mask`` lists. One prompt entry is appended per user turn
        (all tokens from the system turn up to and including that user turn);
        one answer entry is appended per assistant turn (that turn's tokens
        only).

    Raises:
        KeyError: If a turn's ``"from"`` value is neither ``"user"`` nor
            ``"assistant"`` (the ``roles`` lookup fails before the
            ``NotImplementedError`` branch can be reached).
        IndexError: If no user turn or no assistant turn has been collected by
            the time the closing length checks run.
    """
    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}

    im_start = tokenizer.im_start_id
    im_end = tokenizer.im_end_id
    nl_tokens = tokenizer("\n").input_ids
    _system = tokenizer("system").input_ids + nl_tokens

    # Apply prompt templates
    prompt_ids, prompt_targets = [], []
    answer_ids, answer_targets = [], []
    for i, source in enumerate(sources):
        # Drop a leading turn that does not come from the user.
        if roles[source[0]["from"]] != roles["user"]:
            source = source[1:]

        input_id, target = [], []
        system = (
            [im_start]
            + _system
            + tokenizer(system_message).input_ids
            + [im_end]
            + nl_tokens
        )
        input_id += system
        # Only the markers stay visible in the labels. The "- 3" removes
        # im_start, im_end and the newline, which therefore has to be a single
        # token; the length assert below catches a mismatch.
        target += (
            [im_start] + [IGNORE_TOKEN_ID] * (len(system) - 3) + [im_end] + nl_tokens
        )
        assert len(input_id) == len(target)
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            _input_id = (
                tokenizer(role).input_ids
                + nl_tokens
                + tokenizer(sentence["value"]).input_ids
                + [im_end]
                + nl_tokens
            )
            # input_id is extended first, so the prompt snapshot taken in the
            # user branch already contains this turn.
            input_id += _input_id
            if role == "<|im_start|>user":
                # User text is fully ignored in the labels.
                _target = (
                    [im_start]
                    + [IGNORE_TOKEN_ID] * (len(_input_id) - 3)
                    + [im_end]
                    + nl_tokens
                )
                prompt_ids.append(input_id[:])
                prompt_targets.append((target + _target)[:])
            elif role == "<|im_start|>assistant":
                # Ignore the role header; keep the answer text, im_end and the
                # trailing newline as labels.
                _target = (
                    [im_start]
                    + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids)
                    + _input_id[len(tokenizer(role).input_ids) + 1 : -2]
                    + [im_end]
                    + nl_tokens
                )
                answer_ids.append(_input_id[:])
                answer_targets.append(_target[:])
            else:
                raise NotImplementedError
            target += _target
        assert len(input_id) == len(target)
        assert len(prompt_ids[-1]) == len(prompt_targets[-1])
        assert len(answer_ids[-1]) == len(answer_targets[-1])

    # Attention masks are booleans: False wherever the token equals the pad id.
    prompt_sequence_tokens = dict(
        input_ids=prompt_ids,
        labels=prompt_targets,
        attention_mask=[
            [id != tokenizer.pad_token_id for id in ids] for ids in prompt_ids
        ],
    )
    answer_sequence_tokens = dict(
        input_ids=answer_ids,
        labels=answer_targets,
        attention_mask=[
            [id != tokenizer.pad_token_id for id in ids] for ids in answer_ids
        ],
    )

    return prompt_sequence_tokens, answer_sequence_tokens


def read_jsonl(file_path):
    """Read a JSONL file and return a list of the decoded records.

    Args:
        file_path: Path opened through ``File.open`` in UTF-8 text mode.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with File.open(file_path, "r", encoding="utf-8") as file:
        # Blank or whitespace-only lines (e.g. a trailing empty line) are
        # skipped rather than passed to json.loads, which would raise.
        return [json.loads(line) for line in file if line.strip()]


def qwen_vl_prompt_format(prompt, img_paths):
    """Prefix ``prompt`` with one numbered ``<img>`` line per image path.

    Images are numbered from 1 in the order given; the prompt is stripped of
    surrounding whitespace and placed after the last image line.
    """
    picture_lines = [
        f"Picture {number}: <img>{path}</img>\n"
        for number, path in enumerate(img_paths, start=1)
    ]
    return "".join(picture_lines) + prompt.strip()


def make_conv(prompt, answer):
    """Build a two-turn conversation: the user's prompt, then the assistant's answer."""
    user_turn = {"from": "user", "value": prompt}
    assistant_turn = {"from": "assistant", "value": answer}
    return [user_turn, assistant_turn]


@dataclass
class QwenDPODataCollator(DPODataCollatorWithPadding):
    """DPO data collator that tokenizes samples through this file's ``preprocess``."""

    def tokenize_batch_element(
        self,
        prompt: str,
        chosen: str,
        rejected: str,
    ) -> Dict:
        """Tokenize one (prompt, chosen, rejected) triple into flat Python lists.

        Nothing is converted to tensors here. The prompt is tokenized together
        with each response via ``preprocess``; an EOS token is appended to both
        responses; if prompt plus the longer response exceeds ``max_length``
        the prompt is cut to ``max_prompt_length`` (per ``truncation_mode``),
        and if that is still too long both responses are cut to
        ``max_length - max_prompt_length``. The prompt part of the final
        labels is overwritten with ``label_pad_token_id``.

        Returns:
            Dict with ``{chosen,rejected,prompt}_{input_ids,labels,attention_mask}``
            plus the raw strings under ``prompt``, ``chosen``, ``rejected``,
            ``chosen_response_only`` and ``rejected_response_only``.

        Raises:
            ValueError: If truncation is needed and ``truncation_mode`` is
                neither ``"keep_start"`` nor ``"keep_end"``.
        """
        eos_id = self.tokenizer.eos_token_id

        # `preprocess` works on lists of conversations and already sets the
        # labels; each call here carries exactly one conversation.
        prompt_batch, chosen_batch = preprocess(
            [make_conv(prompt, chosen)], self.tokenizer, self.max_length
        )
        _, rejected_batch = preprocess(
            [make_conv(prompt, rejected)], self.tokenizer, self.max_length
        )
        prompt_tokens = {key: seqs[0] for key, seqs in prompt_batch.items()}
        chosen_tokens = {key: seqs[0] for key, seqs in chosen_batch.items()}
        rejected_tokens = {key: seqs[0] for key, seqs in rejected_batch.items()}

        # Zero the attention mask wherever the text itself contains the EOS id.
        for encoded in (prompt_tokens, chosen_tokens, rejected_tokens):
            encoded["attention_mask"] = [
                0 if token == eos_id else flag
                for token, flag in zip(
                    encoded["input_ids"], encoded["attention_mask"]
                )
            ]

        # Terminate both responses with an attended EOS token.
        for response in (chosen_tokens, rejected_tokens):
            response["input_ids"].append(eos_id)
            response["labels"].append(eos_id)
            response["attention_mask"].append(1)

        # Measured once, before any truncation, and reused by both checks.
        longest_response = max(
            len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])
        )

        # First resort: shorten the prompt.
        if len(prompt_tokens["input_ids"]) + longest_response > self.max_length:
            if self.truncation_mode == "keep_start":
                window = slice(None, self.max_prompt_length)
            elif self.truncation_mode == "keep_end":
                window = slice(-self.max_prompt_length, None)
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
            prompt_tokens = {key: seq[window] for key, seq in prompt_tokens.items()}

        # Second resort: shorten the responses as well.
        if len(prompt_tokens["input_ids"]) + longest_response > self.max_length:
            response_budget = self.max_length - self.max_prompt_length
            chosen_tokens = {
                key: seq[:response_budget] for key, seq in chosen_tokens.items()
            }
            rejected_tokens = {
                key: seq[:response_budget] for key, seq in rejected_tokens.items()
            }

        # Prepend the prompt to each response, then blank out the prompt labels.
        prompt_len = len(prompt_tokens["input_ids"])
        chosen_tokens = {
            key: prompt_tokens[key] + seq for key, seq in chosen_tokens.items()
        }
        rejected_tokens = {
            key: prompt_tokens[key] + seq for key, seq in rejected_tokens.items()
        }
        chosen_tokens["labels"][:prompt_len] = [self.label_pad_token_id] * prompt_len
        rejected_tokens["labels"][:prompt_len] = [self.label_pad_token_id] * prompt_len

        batch = {}
        for prefix, encoded in (
            ("chosen", chosen_tokens),
            ("rejected", rejected_tokens),
            ("prompt", prompt_tokens),
        ):
            for type_key, tokens in encoded.items():
                if type_key != "token_type_ids":
                    batch[f"{prefix}_{type_key}"] = tokens

        batch.update(
            {
                "prompt": prompt,
                "chosen": prompt + chosen,
                "rejected": prompt + rejected,
                "chosen_response_only": chosen,
                "rejected_response_only": rejected,
            }
        )
        return batch


def make_vlfeedback_paired_dataset(local_rank, dataset_path, image_folder, dataset_debug=False):
    """Load a preference dataset and reduce it to prompt / chosen / rejected columns.

    The loader is chosen from substrings of ``dataset_path``:

    * contains ``"vlfeedback"``: a JSON-lines file whose records carry ``id``,
      ``prompt`` and ``completions``. Every pair of completions whose mean
      annotation ratings differ yields one (chosen, rejected) row; pairs with
      equal means or non-numeric ratings are skipped.
    * contains ``"hadpo"``: a JSON array whose records carry ``image``,
      ``chosen_conversations`` and ``reject_conversations``.
    * otherwise: a directory loaded with ``datasets.load_from_disk``. Its rows
      must provide ``prompt`` and ``img_path``; presumably ``chosen`` and
      ``rejected`` are already present too - verify against the data.

    Args:
        local_rank: Local rank of this process. Ranks > 0 wait on a barrier
            while local rank 0 runs each ``map`` step; with any other value
            (e.g. -1) no barrier is used.
        dataset_path: Path of the dataset file or directory (see above).
        image_folder: Folder joined with each record's image file name.
        dataset_debug: When true, JSON-based datasets are cut to their first
            100 records. Defaults to ``False`` so three-argument calls work.

    Returns:
        A ``datasets.Dataset`` restricted to the ``prompt``, ``chosen`` and
        ``rejected`` columns.
    """
    import pandas as pd

    keep_columns = {"prompt", "chosen", "rejected"}

    if "vlfeedback" in dataset_path:
        File.prepare(list(File.list(os.path.dirname(dataset_path))))
        # Close the handle deterministically instead of leaking it.
        with File.open(dataset_path) as fp:
            feedback_data = [json.loads(line) for line in fp]
        if dataset_debug:
            feedback_data = feedback_data[:100]
        # The image file name is derived from the record id.
        converted_data = [
            dict(record, img_path=f"{record['id']}.jpg") for record in feedback_data
        ]
        ds = datasets.Dataset.from_pandas(pd.DataFrame(data=converted_data))
    elif "hadpo" in dataset_path:
        with File.open(dataset_path) as fp:
            feedback_data = json.load(fp)
        if dataset_debug:
            feedback_data = feedback_data[:100]
        ds = datasets.Dataset.from_pandas(pd.DataFrame(data=feedback_data))
    else:
        File.prepare(list(File.list(dataset_path)))
        ds = datasets.load_from_disk(dataset_path)

    def set_format(sample):
        # Fetch the image and embed its path into the prompt text.
        prompt = sample["prompt"]
        img_path = os.path.join(image_folder, sample["img_path"])
        File.prepare([img_path])
        sample["prompt"] = qwen_vl_prompt_format(prompt, [img_path])
        return sample

    def set_format_hadpo(sample):
        # Only fetch the image; the prompt is built later in
        # make_hadpo_batch_pairs.
        img_path = os.path.join(image_folder, sample["image"])
        File.prepare([img_path])
        return sample

    def make_batch_pairs(sample):
        converted_sample = defaultdict(list)

        for sample_idx, comps in enumerate(sample["completions"]):
            prompt = sample["prompt"][sample_idx]
            for comp_idx1, comp_idx2 in combinations(range(len(comps)), 2):
                anno1 = comps[comp_idx1]["annotations"]
                anno2 = comps[comp_idx2]["annotations"]

                # Average the ratings over all annotated aspects; a rating
                # that is not a number disqualifies the pair.
                try:
                    avg_score1 = np.mean(
                        [float(anno1[aspect]["Rating"]) for aspect in anno1]
                    )
                    avg_score2 = np.mean(
                        [float(anno2[aspect]["Rating"]) for aspect in anno2]
                    )
                except ValueError:
                    continue

                # The higher-rated response is "chosen"; ties give no signal.
                if avg_score1 > avg_score2:
                    chosen = comps[comp_idx1]["response"]
                    rejected = comps[comp_idx2]["response"]
                elif avg_score2 > avg_score1:
                    chosen = comps[comp_idx2]["response"]
                    rejected = comps[comp_idx1]["response"]
                else:
                    continue
                converted_sample["prompt"].append(prompt)
                converted_sample["chosen"].append(chosen)
                converted_sample["rejected"].append(rejected)
        return converted_sample

    def make_hadpo_batch_pairs(sample):
        converted_sample = defaultdict(list)
        for sample_idx, image in enumerate(sample["image"]):
            img_path = os.path.join(image_folder, image)
            chosen_conversations = sample["chosen_conversations"][sample_idx]
            reject_conversations = sample["reject_conversations"][sample_idx]
            # Both conversations must start from the same question.
            assert chosen_conversations[0]["value"] == reject_conversations[0]["value"]
            prompt = chosen_conversations[0]["value"].replace("<image>", "").strip()
            prompt = qwen_vl_prompt_format(prompt, [img_path])
            chosen = chosen_conversations[1]["value"]
            rejected = reject_conversations[1]["value"]

            converted_sample["prompt"].append(prompt)
            converted_sample["chosen"].append(chosen)
            converted_sample["rejected"].append(rejected)
        return converted_sample

    # Step 1: fetch images / format prompts. Ranks > 0 wait here until local
    # rank 0 has finished its map and reaches the matching barrier below.
    if local_rank > 0:
        print("Waiting for main process to perform the mapping")
        torch.distributed.barrier()

    if "hadpo" not in dataset_path:
        ds = ds.map(set_format)
    else:
        ds = ds.map(set_format_hadpo)

    if local_rank == 0:
        print("Loading results from main process")
        torch.distributed.barrier()

    # Step 2: build comparison pairs, using the same barrier scheme.
    if local_rank > 0:
        print("Waiting for main process to perform the mapping")
        torch.distributed.barrier()

    # `remove_columns` is documented to take a list of names, so pass a list
    # (in column order) rather than a set.
    drop_columns = [name for name in ds.column_names if name not in keep_columns]
    if "vlfeedback" in dataset_path:
        print ("We are in vlfeedback")
        ds = ds.map(
            make_batch_pairs,
            batched=True,
            remove_columns=drop_columns,
        )
    elif "hadpo" in dataset_path:
        print ("We are in hadpo")
        ds = ds.map(
            make_hadpo_batch_pairs,
            batched=True,
            remove_columns=drop_columns,
        )
    else:
        print ("We are in UNK")
        ds = ds.remove_columns(drop_columns)

    if local_rank == 0:
        print("Loading results from main process")
        torch.distributed.barrier()
    return ds


def env_init(distributed=True):
    """Fill in the rendezvous environment variables for distributed runs.

    Nothing is changed when ``distributed`` is false or when neither
    ``OMPI_COMM_WORLD_SIZE`` nor ``WORLD_SIZE`` is set. Otherwise
    ``MASTER_ADDR`` / ``MASTER_PORT`` get defaults if absent, and under Open
    MPI the ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` variables are copied
    from their ``OMPI_COMM_WORLD_*`` counterparts.
    """
    print("Init Env for Distributed Training")
    if not distributed:
        return

    env = os.environ
    under_mpi = 'OMPI_COMM_WORLD_SIZE' in env
    if not under_mpi and 'WORLD_SIZE' not in env:
        return

    # Existing values win; the defaults only fill gaps.
    env.setdefault('MASTER_ADDR', 'localhost')
    env.setdefault('MASTER_PORT', "12875")

    if under_mpi:
        mpi_mapping = (
            ('WORLD_SIZE', 'OMPI_COMM_WORLD_SIZE'),
            ('RANK', 'OMPI_COMM_WORLD_RANK'),
            ('LOCAL_RANK', 'OMPI_COMM_WORLD_LOCAL_RANK'),
        )
        for target, source in mpi_mapping:
            env[target] = env[source]


def train():
    """Parse the command line, build model / tokenizer / datasets and run DPO training.

    Steps, in order: initialise the distributed environment variables, parse
    the four argument dataclasses, fetch the model files through azfuse when
    they are present there, load config / model / tokenizer, optionally build
    a LoRA config, load and merge the configured preference datasets, train
    with ``DPOTrainer`` and finally save and upload the result.

    Raises:
        ValueError: If neither (``data_path``, ``image_folder``) nor
            (``unk_data_path``, ``unk_image_folder``) is fully specified.
    """
    global local_rank

    env_init()

    os.environ["WANDB_PROJECT"] = "Silkie"
    parser = transformers.HfArgumentParser(
        (ModelArguments, TrainingArguments, LoraArguments, DataArguments)
    )
    (
        model_args,
        training_args,
        lora_args,
        data_args,
    ) = parser.parse_args_into_dataclasses()

    # Fetch the checkpoint files when the model path exists in blob storage.
    if File.isfile(os.path.join(model_args.model_name_or_path, "tokenizer_config.json")):
        File.prepare(list(File.list(model_args.model_name_or_path)))
    else:
        print(f"Missing from blob, {model_args.model_name_or_path}")

    if getattr(training_args, "deepspeed", None) and getattr(
        lora_args, "q_lora", False
    ):
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    local_rank = training_args.local_rank

    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if lora_args.q_lora:
        # Under DDP each process keeps the whole quantized model on its own device.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning("FSDP or ZeRO3 are incompatible with QLoRA.")

    # Load the model config; the KV cache is disabled for training.
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
        fp32=True,
    )
    config.use_cache = False

    # Load model and tokenizer
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        trust_remote_code=True,
        quantization_config=GPTQConfig(bits=4, disable_exllama=True)
        if training_args.use_lora and lora_args.q_lora
        else None,
    )

    if not training_args.use_lora:
        if (
            training_args.fix_vit
            and hasattr(model, "transformer")
            and hasattr(model.transformer, "visual")
        ):
            # Freeze the visual tower but keep its attention pooling trainable.
            model.transformer.visual.requires_grad_(False)
            if hasattr(model.transformer.visual, "attn_pool"):
                model.transformer.visual.attn_pool.requires_grad_(True)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True,
    )
    # Both padding and end-of-sequence use the tokenizer's `eod_id`.
    tokenizer.pad_token_id = tokenizer.eod_id
    tokenizer.eos_token_id = tokenizer.eod_id

    if training_args.use_lora:
        if lora_args.q_lora or "chat" in model_args.model_name_or_path.lower():
            modules_to_save = None
            print ("NOT Inside LoRA")
        else:
            print ("Inside LoRA")
            modules_to_save = ["wte", "lm_head"]
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",
            modules_to_save=modules_to_save,  # This argument serves for adding new tokens.
        )
        if lora_args.q_lora:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )

        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()

    # Load data: each source is used only when both its path and image folder
    # are given.
    if data_args.data_path is not None and data_args.image_folder is not None:
        dataset = make_vlfeedback_paired_dataset(
            training_args.local_rank,
            data_args.data_path,
            data_args.image_folder,
            dataset_debug=data_args.dataset_debug,
        )
        dataset_split = dataset.train_test_split(test_size=0.005, seed=42)
        train_dataset = dataset_split["train"]
        eval_dataset = dataset_split["test"]
        print(f"train: {len(train_dataset)}, eval: {len(eval_dataset)}")
    else:
        train_dataset = None
        eval_dataset = None

    if data_args.unk_data_path is not None and data_args.unk_image_folder is not None:
        # `dataset_debug` is passed explicitly; it was missing here, which made
        # this call fail with a TypeError.
        unk_dataset = make_vlfeedback_paired_dataset(
            training_args.local_rank,
            data_args.unk_data_path,
            data_args.unk_image_folder,
            dataset_debug=data_args.dataset_debug,
        )
        unk_dataset_split = unk_dataset.train_test_split(test_size=0.005, seed=42)
        unk_train_dataset = unk_dataset_split["train"]
        unk_eval_dataset = unk_dataset_split["test"]
        print(f"train: {len(unk_train_dataset)}, eval: {len(unk_eval_dataset)}")
    else:
        unk_train_dataset = None
        unk_eval_dataset = None

    if train_dataset is None:
        train_dataset = unk_train_dataset
        eval_dataset = unk_eval_dataset
    elif unk_train_dataset is not None:
        train_dataset = concatenate_datasets([train_dataset, unk_train_dataset])
        eval_dataset = concatenate_datasets([eval_dataset, unk_eval_dataset])

    if train_dataset is None:
        # Fail with a clear message instead of `len(None)` below.
        raise ValueError(
            "No training data configured: pass --data_path with --image_folder "
            "and/or --unk_data_path with --unk_image_folder."
        )
    print("train_dataset:", len(train_dataset))

    # Start trainer
    trainer = DPOTrainer(
        model,
        args=training_args,
        beta=training_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=QwenDPODataCollator(
            tokenizer,
            max_length=training_args.model_max_length,
            max_prompt_length=training_args.model_max_length // 2,
            max_target_length=training_args.model_max_length // 2,
            label_pad_token_id=IGNORE_TOKEN_ID,
            padding_value=tokenizer.pad_token_id,
            truncation_mode="keep_end",
        ),
        tokenizer=tokenizer,
        max_length=training_args.model_max_length,
        peft_config=lora_config if training_args.use_lora else None,
        generate_during_eval=training_args.generate_during_eval,
    )

    trainer.train()
    trainer.save_state()

    safe_save_model_for_hf_trainer(
        trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias
    )


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()
